{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 0
   },
   "source": [
    "# The Dataset for Pretraining Word Embedding\n",
    ":label:`sec_word2vec_data`\n",
    "\n",
    "In this section, we will introduce how to preprocess a dataset with\n",
    "negative sampling :numref:`sec_approx_train` and load into minibatches for\n",
    "word2vec training. The dataset we use is [Penn Tree Bank (PTB)]( https://catalog.ldc.upenn.edu/LDC99T42), which is a small but commonly-used corpus. It takes samples from Wall Street Journal articles and includes training sets, validation sets, and test sets.\n",
    "\n",
    "First, import the packages and modules required for the experiment.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%load ../utils/djl-imports\n",
    "%load ../utils/plot-utils\n",
    "%load ../utils/Functions.java\n",
    "%load ../utils/PlotUtils.java\n",
    "\n",
    "%load ../utils/StopWatch.java\n",
    "%load ../utils/Accumulator.java\n",
    "%load ../utils/Animator.java\n",
    "%load ../utils/Training.java\n",
    "%load ../utils/timemachine/Vocab.java"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import java.util.stream.*;\n",
    "import org.apache.commons.math3.distribution.EnumeratedDistribution;"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDManager manager = NDManager.newBaseManager();"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 3
   },
   "source": [
    "## Reading and Preprocessing the Dataset\n",
    "\n",
    "This dataset has already been preprocessed. Each line of the dataset acts as a sentence. All the words in a sentence are separated by spaces. In the word embedding task, each word is a token.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static String[][] readPTB() throws IOException {\n",
    "    String ptbURL = \"http://d2l-data.s3-accelerate.amazonaws.com/ptb.zip\";\n",
    "    InputStream input = new URL(ptbURL).openStream();\n",
    "    ZipUtils.unzip(input, Paths.get(\"./\"));\n",
    "\n",
    "    ArrayList<String> lines = new ArrayList<>();\n",
    "    File file = new File(\"./ptb/ptb.train.txt\");\n",
    "    Scanner myReader = new Scanner(file);\n",
    "    while (myReader.hasNextLine()) {\n",
    "        lines.add(myReader.nextLine());\n",
    "    }\n",
    "    String[][] tokens = new String[lines.size()][];\n",
    "    for (int i = 0; i < lines.size(); i++) {\n",
    "        tokens[i] = lines.get(i).trim().split(\" \");\n",
    "    }\n",
    "    return tokens;\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "String[][] sentences = readPTB();\n",
    "System.out.println(\"# sentences: \" + sentences.length);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 5
   },
   "source": [
    "Next we build a vocabulary with words appeared not greater than 10 times mapped into a \"&lt;unk&gt;\" token. Note that the preprocessed PTB data also contains \"&lt;unk&gt;\" tokens presenting rare words.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Vocab vocab = new Vocab(sentences, 10, new String[] {});\n",
    "System.out.println(vocab.length());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 7
   },
   "source": [
    "## Subsampling\n",
    "\n",
    "In text data, there are generally some words that appear at high frequencies, such \"the\", \"a\", and \"in\" in English. Generally speaking, in a context window, it is better to train the word embedding model when a word (such as \"chip\") and a lower-frequency word (such as \"microprocessor\") appear at the same time, rather than when a word appears with a higher-frequency word (such as \"the\"). Therefore, when training the word embedding model, we can perform subsampling on the words :cite:`Mikolov.Sutskever.Chen.ea.2013`. Specifically, each indexed word $w_i$ in the dataset will drop out at a certain probability. The dropout probability is given as:\n",
    "\n",
    "$$ P(w_i) = \\max\\left(1 - \\sqrt{\\frac{t}{f(w_i)}}, 0\\right),$$\n",
    "\n",
    "Here, $f(w_i)$ is the ratio of the instances of word $w_i$ to the total number of words in the dataset, and the constant $t$ is a hyperparameter (set to $10^{-4}$ in this experiment). As we can see, it is only possible to drop out the word $w_i$ in subsampling when $f(w_i) > t$. The higher the word's frequency, the higher its dropout probability.\n"
   ]
  },
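  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the formula concrete, the following minimal sketch evaluates the dropout probability for a few relative frequencies. The helper name `dropProbability` and the example frequencies are illustrative assumptions, not values measured on PTB; only $t = 10^{-4}$ follows the text above. Words with $f(w_i) \\le t$ should come out with probability 0.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Illustrative helper (not part of the pipeline below): dropout probability\n",
    "// P(w) = max(1 - sqrt(t / f), 0) for a word with relative frequency f\n",
    "public static double dropProbability(double f, double t) {\n",
    "    return Math.max(1 - Math.sqrt(t / f), 0);\n",
    "}\n",
    "\n",
    "double[] exampleFreqs = {0.05, 1e-3, 1e-4, 1e-5}; // assumed relative frequencies\n",
    "for (double f : exampleFreqs) {\n",
    "    System.out.println(\"f = \" + f + \", P(drop) = \" + dropProbability(f, 1e-4));\n",
    "}"
   ]
  },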
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static boolean keep(String token, LinkedHashMap<?, Integer> counter, int numTokens) {\n",
    "    // Return True if to keep this token during subsampling\n",
    "    return new Random().nextFloat() < Math.sqrt(1e-4 / counter.get(token) * numTokens);\n",
    "}\n",
    "\n",
    "public static String[][] subSampling(String[][] sentences, Vocab vocab) {\n",
    "    for (int i = 0; i < sentences.length; i++) {\n",
    "        for (int j = 0; j < sentences[i].length; j++) {\n",
    "            sentences[i][j] = vocab.idxToToken.get(vocab.getIdx(sentences[i][j]));\n",
    "        }\n",
    "    }\n",
    "    // Count the frequency for each word\n",
    "    LinkedHashMap<?, Integer> counter = vocab.countCorpus2D(sentences);\n",
    "    int numTokens = 0;\n",
    "    for (Integer value : counter.values()) {\n",
    "        numTokens += value;\n",
    "    }\n",
    "\n",
    "    // Now do the subsampling\n",
    "    String[][] output = new String[sentences.length][];\n",
    "    for (int i = 0; i < sentences.length; i++) {\n",
    "        ArrayList<String> tks = new ArrayList<>();\n",
    "        for (int j = 0; j < sentences[i].length; j++) {\n",
    "            String tk = sentences[i][j];\n",
    "            if (keep(sentences[i][j], counter, numTokens)) {\n",
    "                tks.add(tk);\n",
    "            }\n",
    "        }\n",
    "        output[i] = tks.toArray(new String[tks.size()]);\n",
    "    }\n",
    "\n",
    "    return output;\n",
    "}\n",
    "\n",
    "String[][] subsampled = subSampling(sentences, vocab);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 9
   },
   "source": [
    "Compare the sequence lengths before and after sampling, we can see subsampling significantly reduced the sequence length.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "double[] y1 = new double[sentences.length];\n",
    "for (int i = 0; i < sentences.length; i++) y1[i] = sentences[i].length;\n",
    "double[] y2 = new double[subsampled.length];\n",
    "for (int i = 0; i < subsampled.length; i++) y2[i] = subsampled[i].length;\n",
    "\n",
    "HistogramTrace trace1 =\n",
    "        HistogramTrace.builder(y1).opacity(.75).name(\"origin\").nBinsX(20).build();\n",
    "HistogramTrace trace2 =\n",
    "        HistogramTrace.builder(y2).opacity(.75).name(\"subsampled\").nBinsX(20).build();\n",
    "\n",
    "Layout layout =\n",
    "        Layout.builder()\n",
    "                .barMode(Layout.BarMode.GROUP)\n",
    "                .showLegend(true)\n",
    "                .xAxis(Axis.builder().title(\"# tokens per sentence\").build())\n",
    "                .yAxis(Axis.builder().title(\"count\").build())\n",
    "                .build();\n",
    "new Figure(layout, trace1, trace2);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 11
   },
   "source": [
    "For individual tokens, the sampling rate of the high-frequency word \"the\" is less than 1/20.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static String compareCounts(String token, String[][] sentences, String[][] subsampled) {\n",
    "    int beforeCount = 0;\n",
    "    for (int i = 0; i < sentences.length; i++) {\n",
    "        for (int j = 0; j < sentences[i].length; j++) {\n",
    "            if (sentences[i][j].equals(token)) beforeCount += 1;\n",
    "        }\n",
    "    }\n",
    "\n",
    "    int afterCount = 0;\n",
    "    for (int i = 0; i < subsampled.length; i++) {\n",
    "        for (int j = 0; j < subsampled[i].length; j++) {\n",
    "            if (subsampled[i][j].equals(token)) afterCount += 1;\n",
    "        }\n",
    "    }\n",
    "\n",
    "    return \"# of \\\"the\\\": before=\" + beforeCount + \", after=\" + afterCount;\n",
    "}\n",
    "\n",
    "System.out.println(compareCounts(\"the\", sentences, subsampled));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 13
   },
   "source": [
    "But the low-frequency word \"join\" is completely preserved.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "System.out.println(compareCounts(\"join\", sentences, subsampled));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 15
   },
   "source": [
    "Last, we map each token into an index to construct the corpus.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Integer[][] corpus = new Integer[subsampled.length][];\n",
    "for (int i = 0; i < subsampled.length; i++) {\n",
    "    corpus[i] = vocab.getIdxs(subsampled[i]);\n",
    "}\n",
    "for (int i = 0; i < 3; i++) {\n",
    "    System.out.println(Arrays.toString(corpus[i]));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 17
   },
   "source": [
    "## Loading the Dataset\n",
    "\n",
    "Next we read the corpus with token indicies into data batches for training.\n",
    "\n",
    "### Extracting Central Target Words and Context Words\n",
    "\n",
    "We use words with a distance from the central target word not exceeding the context window size as the context words of the given center target word. The following definition function extracts all the central target words and their context words. It uniformly and randomly samples an integer to be used as the context window size between integer 1 and the `maxWindowSize` (maximum context window).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static Pair<ArrayList<Integer>, ArrayList<ArrayList<Integer>>> getCentersAndContext(\n",
    "        Integer[][] corpus, int maxWindowSize) {\n",
    "    ArrayList<Integer> centers = new ArrayList<>();\n",
    "    ArrayList<ArrayList<Integer>> contexts = new ArrayList<>();\n",
    "\n",
    "    for (Integer[] line : corpus) {\n",
    "        // Each sentence needs at least 2 words to form a \"central target word\n",
    "        // - context word\" pair\n",
    "        if (line.length < 2) {\n",
    "            continue;\n",
    "        }\n",
    "        centers.addAll(Arrays.asList(line));\n",
    "        for (int i = 0; i < line.length; i++) { // Context window centered at i\n",
    "            int windowSize = new Random().nextInt(maxWindowSize - 1) + 1;\n",
    "            List<Integer> indices =\n",
    "                    IntStream.range(\n",
    "                                    Math.max(0, i - windowSize),\n",
    "                                    Math.min(line.length, i + 1 + windowSize))\n",
    "                            .boxed()\n",
    "                            .collect(Collectors.toList());\n",
    "            // Exclude the central target word from the context words\n",
    "            indices.remove(indices.indexOf(i));\n",
    "            ArrayList<Integer> context = new ArrayList<>();\n",
    "            for (Integer idx : indices) {\n",
    "                context.add(line[idx]);\n",
    "            }\n",
    "            contexts.add(context);\n",
    "        }\n",
    "    }\n",
    "    return new Pair<>(centers, contexts);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 19
   },
   "source": [
    "Next, we create an artificial dataset containing two sentences of 7 and 3 words, respectively. Assume the maximum context window is 2 and print all the central target words and their context words.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Integer[][] tinyDataset =\n",
    "        new Integer[][] {\n",
    "            IntStream.range(0, 7)\n",
    "                    .boxed()\n",
    "                    .collect(Collectors.toList())\n",
    "                    .toArray(new Integer[] {}),\n",
    "            IntStream.range(7, 10)\n",
    "                    .boxed()\n",
    "                    .collect(Collectors.toList())\n",
    "                    .toArray(new Integer[] {})\n",
    "        };\n",
    "\n",
    "System.out.println(\"dataset \" + Arrays.deepToString(tinyDataset));\n",
    "Pair<ArrayList<Integer>, ArrayList<ArrayList<Integer>>> centerContextPair =\n",
    "        getCentersAndContext(tinyDataset, 2);\n",
    "for (int i = 0; i < centerContextPair.getValue().size(); i++) {\n",
    "    System.out.println(\n",
    "            \"Center \"\n",
    "                    + centerContextPair.getKey().get(i)\n",
    "                    + \" has contexts\"\n",
    "                    + centerContextPair.getValue().get(i));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 21
   },
   "source": [
    "We set the maximum context window size to 5. The following extracts all the central target words and their context words in the dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "centerContextPair = getCentersAndContext(corpus, 5);\n",
    "ArrayList<Integer> allCenters = centerContextPair.getKey();\n",
    "ArrayList<ArrayList<Integer>> allContexts = centerContextPair.getValue();\n",
    "System.out.println(\"# center-context pairs:\" + allCenters.size());"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 23
   },
   "source": [
    "### Negative Sampling\n",
    "\n",
    "We use negative sampling for approximate training. For a central and context word pair, we randomly sample $K$ noise words ($K=5$ in the experiment). According to the suggestion in the Word2vec paper, the noise word sampling probability $P(w)$ is the ratio of the word frequency of $w$ to the total word frequency raised to the power of 0.75 :cite:`Mikolov.Sutskever.Chen.ea.2013`.\n",
    "\n",
    "We first define a class to draw a candidate according to the sampling weights. It caches a 10,000 size random number bank.\n"
   ]
  },
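  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a small worked example of this weighting, the sketch below turns raw word counts into noise-sampling probabilities by raising each count to the power of 0.75 and normalizing. The helper `noiseDistribution` and the counts are made up for illustration and are not used elsewhere in this section. Compared with sampling by raw frequency, the exponent lowers the probability of the most frequent word and raises that of the rarest one.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Illustrative helper (an assumption, not part of the pipeline below):\n",
    "// P(w) = count(w)^0.75 / sum over v of count(v)^0.75\n",
    "public static double[] noiseDistribution(int[] counts) {\n",
    "    double[] probs = new double[counts.length];\n",
    "    double total = 0;\n",
    "    for (int i = 0; i < counts.length; i++) {\n",
    "        probs[i] = Math.pow(counts[i], 0.75);\n",
    "        total += probs[i];\n",
    "    }\n",
    "    for (int i = 0; i < probs.length; i++) {\n",
    "        probs[i] /= total; // normalize so that the probabilities sum to 1\n",
    "    }\n",
    "    return probs;\n",
    "}\n",
    "\n",
    "int[] toyCounts = {10000, 100, 1}; // hypothetical word counts\n",
    "System.out.println(Arrays.toString(noiseDistribution(toyCounts)));"
   ]
  },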
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public class RandomGenerator {\n",
    "    /* Draw a random int in [0, n] according to n sampling weights. */\n",
    "\n",
    "    private List<Integer> population;\n",
    "    private List<Double> samplingWeights;\n",
    "    private List<Integer> candidates;\n",
    "    private List<org.apache.commons.math3.util.Pair<Integer, Double>> pmf;\n",
    "    private int i;\n",
    "\n",
    "    public RandomGenerator(List<Double> samplingWeights) {\n",
    "        this.population =\n",
    "                IntStream.range(0, samplingWeights.size()).boxed().collect(Collectors.toList());\n",
    "        this.samplingWeights = samplingWeights;\n",
    "        this.candidates = new ArrayList<>();\n",
    "        this.i = 0;\n",
    "\n",
    "        this.pmf = new ArrayList<>();\n",
    "        for (int i = 0; i < samplingWeights.size(); i++) {\n",
    "            this.pmf.add(new org.apache.commons.math3.util.Pair(this.population.get(i), this.samplingWeights.get(i).doubleValue()));\n",
    "        }\n",
    "    }\n",
    "\n",
    "    public Integer draw() {\n",
    "        if (this.i == this.candidates.size()) {\n",
    "            this.candidates =\n",
    "                    Arrays.asList((Integer[]) new EnumeratedDistribution(this.pmf).sample(10000, new Integer[] {}));\n",
    "            this.i = 0;\n",
    "        }\n",
    "        this.i += 1;\n",
    "        return this.candidates.get(this.i - 1);\n",
    "    }\n",
    "}\n",
    "\n",
    "RandomGenerator generator =\n",
    "        new RandomGenerator(Arrays.asList(new Double[] {2.0, 3.0, 4.0}));\n",
    "Integer[] generatorOutput = new Integer[10];\n",
    "for (int i = 0; i < 10; i++) {\n",
    "    generatorOutput[i] = generator.draw();\n",
    "}\n",
    "System.out.println(Arrays.toString(generatorOutput));"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static ArrayList<ArrayList<Integer>> getNegatives(\n",
    "        ArrayList<ArrayList<Integer>> allContexts, Integer[][] corpus, int K) {\n",
    "    LinkedHashMap<?, Integer> counter = Vocab.countCorpus2D(corpus);\n",
    "    ArrayList<Double> samplingWeights = new ArrayList<>();\n",
    "    for (Map.Entry<?, Integer> entry : counter.entrySet()) {\n",
    "        samplingWeights.add(Math.pow(entry.getValue(), .75));\n",
    "    }\n",
    "    ArrayList<ArrayList<Integer>> allNegatives = new ArrayList<>();\n",
    "    RandomGenerator generator = new RandomGenerator(samplingWeights);\n",
    "    for (ArrayList<Integer> contexts : allContexts) {\n",
    "        ArrayList<Integer> negatives = new ArrayList<>();\n",
    "        while (negatives.size() < contexts.size() * K) {\n",
    "            Integer neg = generator.draw();\n",
    "            // Noise words cannot be context words\n",
    "            if (!contexts.contains(neg)) {\n",
    "                negatives.add(neg);\n",
    "            }\n",
    "        }\n",
    "        allNegatives.add(negatives);\n",
    "    }\n",
    "    return allNegatives;\n",
    "}\n",
    "\n",
    "ArrayList<ArrayList<Integer>> allNegatives = getNegatives(allContexts, corpus, 5);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 26
   },
   "source": [
    "### Reading into Batches\n",
    "\n",
    "We extract all central target words `allCenters`, and the context words `allContexts` and noise words `allNegatives` of each central target word from the dataset. We will read them in random minibatches.\n",
    "\n",
    "In a minibatch of data, the $i^\\mathrm{th}$ example includes a central word and its corresponding $n_i$ context words and $m_i$ noise words. Since the context window size of each example may be different, the sum of context words and noise words, $n_i+m_i$, will be different. When constructing a minibatch, we concatenate the context words and noise words of each example, and add 0s for padding until the length of the concatenations are the same, that is, the length of all concatenations is $\\max_i n_i+m_i$(`maxLen`). In order to avoid the effect of padding on the loss function calculation, we construct the mask variable `masks`, each element of which corresponds to an element in the concatenation of context and noise words, `contextsNegatives`. When an element in the variable `contextsNegatives` is a padding, the element in the mask variable `masks` at the same position will be 0. Otherwise, it takes the value 1. In order to distinguish between positive and negative examples, we also need to distinguish the context words from the noise words in the `contextsNegatives` variable. Based on the construction of the mask variable, we only need to create a label variable `labels` with the same shape as the `contextsNegatives` variable and set the elements corresponding to context words (positive examples) to 1, and the rest to 0.\n",
    "\n",
    "Next, we will implement the minibatch reading function `batchifyData`. Its minibatch input `data` is a list of `NDArrays`, each element of which contains central target words `center`, context words `context`, and noise words `negative`. The minibatch data returned by this function conforms to the format we need, for example, it includes the mask variable.\n"
   ]
  },
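  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before turning to `NDArray`s, the padding scheme can be sketched on plain Java arrays. The helper `padExample` below is hypothetical and only illustrates how one row of `contextsNegatives`, `masks`, and `labels` is laid out; it assumes that `maxLen` is at least the number of context words plus noise words, and it is not the implementation used in the rest of this section.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "// Hypothetical helper: builds one padded row {contextsNegatives, masks, labels}.\n",
    "// Assumes maxLen >= context.length + negative.length.\n",
    "public static int[][] padExample(int[] context, int[] negative, int maxLen) {\n",
    "    int[] contextsNegatives = new int[maxLen]; // unfilled positions stay 0 (padding)\n",
    "    int[] masks = new int[maxLen];\n",
    "    int[] labels = new int[maxLen];\n",
    "    int n = context.length;\n",
    "    for (int i = 0; i < n; i++) {\n",
    "        contextsNegatives[i] = context[i];\n",
    "        masks[i] = 1; // real element\n",
    "        labels[i] = 1; // context word: positive example\n",
    "    }\n",
    "    for (int i = 0; i < negative.length; i++) {\n",
    "        contextsNegatives[n + i] = negative[i];\n",
    "        masks[n + i] = 1; // real element; its label stays 0 (negative example)\n",
    "    }\n",
    "    return new int[][] {contextsNegatives, masks, labels};\n",
    "}\n",
    "\n",
    "// Two context words and three noise words, padded to length 6\n",
    "int[][] toyRow = padExample(new int[] {2, 2}, new int[] {3, 3, 3}, 6);\n",
    "System.out.println(\"contextsNegatives: \" + Arrays.toString(toyRow[0]));\n",
    "System.out.println(\"masks:             \" + Arrays.toString(toyRow[1]));\n",
    "System.out.println(\"labels:            \" + Arrays.toString(toyRow[2]));"
   ]
  },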
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDList batchifyData(NDList[] data) {\n",
    "    NDList centers = new NDList();\n",
    "    NDList contextsNegatives = new NDList();\n",
    "    NDList masks = new NDList();\n",
    "    NDList labels = new NDList();\n",
    "\n",
    "    long maxLen = 0;\n",
    "    for (NDList ndList : data) { // center, context, negative = ndList\n",
    "        maxLen =\n",
    "                Math.max(\n",
    "                        maxLen,\n",
    "                        ndList.get(1).countNonzero().getLong()\n",
    "                                + ndList.get(2).countNonzero().getLong());\n",
    "    }\n",
    "    for (NDList ndList : data) { // center, context, negative = ndList\n",
    "        NDArray center = ndList.get(0);\n",
    "        NDArray context = ndList.get(1);\n",
    "        NDArray negative = ndList.get(2);\n",
    "\n",
    "        int count = 0;\n",
    "        for (int i = 0; i < context.size(); i++) {\n",
    "            // If a 0 is found, we want to stop adding these\n",
    "            // values to NDArray\n",
    "            if (context.get(i).getInt() == 0) {\n",
    "                break;\n",
    "            }\n",
    "            contextsNegatives.add(context.get(i).reshape(1));\n",
    "            masks.add(manager.create(1).reshape(1));\n",
    "            labels.add(manager.create(1).reshape(1));\n",
    "            count += 1;\n",
    "        }\n",
    "        for (int i = 0; i < negative.size(); i++) {\n",
    "            // If a 0 is found, we want to stop adding these\n",
    "            // values to NDArray\n",
    "            if (negative.get(i).getInt() == 0) {\n",
    "                break;\n",
    "            }\n",
    "            contextsNegatives.add(negative.get(i).reshape(1));\n",
    "            masks.add(manager.create(1).reshape(1));\n",
    "            labels.add(manager.create(0).reshape(1));\n",
    "            count += 1;\n",
    "        }\n",
    "        // Fill with zeroes remaining array\n",
    "        while (count != maxLen) {\n",
    "            contextsNegatives.add(manager.create(0).reshape(1));\n",
    "            masks.add(manager.create(0).reshape(1));\n",
    "            labels.add(manager.create(0).reshape(1));\n",
    "            count += 1;\n",
    "        }\n",
    "\n",
    "        // Add this NDArrays to output NDArrays\n",
    "        centers.add(center.reshape(1));\n",
    "    }\n",
    "    return new NDList(\n",
    "            NDArrays.concat(centers).reshape(data.length, -1),\n",
    "            NDArrays.concat(contextsNegatives).reshape(data.length, -1),\n",
    "            NDArrays.concat(masks).reshape(data.length, -1),\n",
    "            NDArrays.concat(labels).reshape(data.length, -1));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 28
   },
   "source": [
    "Construct two simple examples:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "NDList x1 =\n",
    "        new NDList(\n",
    "                manager.create(new int[] {1}),\n",
    "                manager.create(new int[] {2, 2}),\n",
    "                manager.create(new int[] {3, 3, 3, 3}));\n",
    "NDList x2 =\n",
    "        new NDList(\n",
    "                manager.create(new int[] {1}),\n",
    "                manager.create(new int[] {2, 2, 2}),\n",
    "                manager.create(new int[] {3, 3}));\n",
    "\n",
    "NDList batchedData = batchifyData(new NDList[] {x1, x2});\n",
    "String[] names = new String[] {\"centers\", \"contexts_negatives\", \"masks\", \"labels\"};\n",
    "for (int i = 0; i < batchedData.size(); i++) {\n",
    "    System.out.println(names[i] + \" shape: \" + batchedData.get(i));\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 30
   },
   "source": [
    "We use the `batchifyData` function just defined to specify the minibatch reading method for the `ArrayDataset` instance iterator.\n",
    "\n",
    "## Putting All Things Together\n",
    "\n",
    "Last, we define the `loadDataPTB` function that read the PTB dataset and return the dataset. In addition, we will create a function called `convertNDArray` that will convert the `centers`, `contexts`, and `negatives` lists, into `NDArrays` by putting 0s where there is no data in order for the rows to have the same lenghts.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "public static NDList convertNDArray(Object[] data, NDManager manager) {\n",
    "    ArrayList<Integer> centers = (ArrayList<Integer>) data[0];\n",
    "    ArrayList<ArrayList<Integer>> contexts = (ArrayList<ArrayList<Integer>>) data[1];\n",
    "    ArrayList<ArrayList<Integer>> negatives = (ArrayList<ArrayList<Integer>>) data[2];\n",
    "\n",
    "    // Create centers NDArray\n",
    "    NDArray centersNDArray = manager.create(centers.stream().mapToInt(i -> i).toArray());\n",
    "\n",
    "    // Create contexts NDArray\n",
    "    int maxLen = 0;\n",
    "    for (ArrayList<Integer> context : contexts) {\n",
    "        maxLen = Math.max(maxLen, context.size());\n",
    "    }\n",
    "    // Fill arrays with 0s to all have same lengths and be able to create NDArray\n",
    "    for (ArrayList<Integer> context : contexts) {\n",
    "        while (context.size() != maxLen) {\n",
    "            context.add(0);\n",
    "        }\n",
    "    }\n",
    "    NDArray contextsNDArray =\n",
    "            manager.create(\n",
    "                    contexts.stream()\n",
    "                            .map(u -> u.stream().mapToInt(i -> i).toArray())\n",
    "                            .toArray(int[][]::new));\n",
    "\n",
    "    // Create negatives NDArray\n",
    "    maxLen = 0;\n",
    "    for (ArrayList<Integer> negative : negatives) {\n",
    "        maxLen = Math.max(maxLen, negative.size());\n",
    "    }\n",
    "    // Fill arrays with 0s to all have same lengths and be able to create NDArray\n",
    "    for (ArrayList<Integer> negative : negatives) {\n",
    "        while (negative.size() != maxLen) {\n",
    "            negative.add(0);\n",
    "        }\n",
    "    }\n",
    "    NDArray negativesNDArray =\n",
    "            manager.create(\n",
    "                    negatives.stream()\n",
    "                            .map(u -> u.stream().mapToInt(i -> i).toArray())\n",
    "                            .toArray(int[][]::new));\n",
    "\n",
    "    return new NDList(centersNDArray, contextsNDArray, negativesNDArray);\n",
    "}\n",
    "\n",
    "public static Pair<ArrayDataset, Vocab> loadDataPTB(\n",
    "        int batchSize, int maxWindowSize, int numNoiseWords, NDManager manager)\n",
    "        throws IOException, TranslateException {\n",
    "    String[][] sentences = readPTB();\n",
    "    Vocab vocab = new Vocab(sentences, 10, new String[] {});\n",
    "    String[][] subSampled = subSampling(sentences, vocab);\n",
    "    Integer[][] corpus = new Integer[subSampled.length][];\n",
    "    for (int i = 0; i < subSampled.length; i++) {\n",
    "        corpus[i] = vocab.getIdxs(subSampled[i]);\n",
    "    }\n",
    "    Pair<ArrayList<Integer>, ArrayList<ArrayList<Integer>>> pair =\n",
    "            getCentersAndContext(corpus, maxWindowSize);\n",
    "    ArrayList<ArrayList<Integer>> negatives =\n",
    "            getNegatives(pair.getValue(), corpus, numNoiseWords);\n",
    "\n",
    "    NDList ndArrays =\n",
    "            convertNDArray(new Object[] {pair.getKey(), pair.getValue(), negatives}, manager);\n",
    "    ArrayDataset dataset =\n",
    "            new ArrayDataset.Builder()\n",
    "                    .setData(ndArrays.get(0), ndArrays.get(1), ndArrays.get(2))\n",
    "                    .optDataBatchifier(\n",
    "                            new Batchifier() {\n",
    "                                @Override\n",
    "                                public NDList batchify(NDList[] ndLists) {\n",
    "                                    return batchifyData(ndLists);\n",
    "                                }\n",
    "\n",
    "                                @Override\n",
    "                                public NDList[] unbatchify(NDList ndList) {\n",
    "                                    return new NDList[0];\n",
    "                                }\n",
    "                            })\n",
    "                    .setSampling(batchSize, true)\n",
    "                    .build();\n",
    "\n",
    "    return new Pair<>(dataset, vocab);\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 33
   },
   "source": [
    "Let us print the first minibatch of the dataset.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "Pair<ArrayDataset, Vocab> datasetVocab = loadDataPTB(512, 5, 5, manager);\n",
    "ArrayDataset dataset = datasetVocab.getKey();\n",
    "vocab = datasetVocab.getValue();\n",
    "\n",
    "Batch batch = dataset.getData(manager).iterator().next();\n",
    "for (int i = 0; i < batch.getData().size(); i++) {\n",
    "    System.out.println(names[i] + \" shape: \" + batch.getData().get(i).getShape());\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "origin_pos": 35
   },
   "source": [
    "## Summary\n",
    "\n",
    "* Subsampling attempts to minimize the impact of high-frequency words on the training of a word embedding model.\n",
    "* We can pad examples of different lengths to create minibatches with examples of all the same length and use mask variables to distinguish between padding and non-padding elements, so that only non-padding elements participate in the calculation of the loss function.\n",
    "\n",
    "## Exercises\n",
    "\n",
    "1. We use the `batchifyData` function to specify the minibatch reading method for the `ArrayDataset` instance iterator and print the shape of each variable in the first batch read. How should these shapes be calculated?\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Java",
   "language": "java",
   "name": "java"
  },
  "language_info": {
   "codemirror_mode": "java",
   "file_extension": ".jshell",
   "mimetype": "text/x-java-source",
   "name": "Java",
   "pygments_lexer": "java",
   "version": "14.0.2+12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
